{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adjusted Mutual Information Tutorial:\n",
    "In this tutorial, we will go over how to run the adjusted_mutual.py script. The purpose of the script is to allow for the calculation of the Adjusted Mutual Information between sets of atlases. This script outputs a png heatmap and csv file containing the AMI values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, the relevant libraries must be imported. adjusted_mutual uses the nibabel, numpy, glob, sklearn, argparse, and matplotlib libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nibabel as nb\n",
    "import numpy as np\n",
    "import glob\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn import metrics as skm\n",
    "from argparse import ArgumentParser"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we have to specify the inputs. In this Jupyter notebook, sys.argv is used to store the inputs. When running the script from the terminal, the call has the following form:\n",
    "\n",
    "python $<$path_to_script$>$/adjusted_mutual.py $<$input_dir$>$ --output_dir $<$output_dir$>$ --fig_name $<$fig_name$>$ --voxel_size $<$vox_size$>$ --atlases $<$atlas1$>$ $<$atlas2$>$ $<$atlas3$>$ ...\n",
    "    \n",
    "input_dir = The path to the directory containing the atlases you intend to analyze\n",
    "\n",
    "output_dir = The path to the directory where you intend to have the output files saved. If not specified, the input directory is used\n",
    "\n",
    "fig_name = Name of the output png and csv files containing the adjusted mutual information heatmap. If not specified, then the files will be called \"AMI_Matrix\"\n",
    "\n",
    "voxel_size = Voxel size of the atlases to be analyzed. This filters the files located in the input_dir folder for anything matching <atlas_name>\\_res-<VOX>x<VOX>x<VOX>.nii.gz. Default is 1, meaning any file ending in \\_res-1x1x1.nii.gz will be selected if no atlases are specified.\n",
    "\n",
    "atlases = List of atlas names to analyze, located in the directory specified by input_dir. If not specified, then all atlases in the input directory that meet the standards set by voxel_size will be analyzed. If specified, these override the value specified in '--voxel_size'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "input_dir = '/Users/ross/Documents/neuroparc/atlases/label/Human'\n",
    "output_dir = '/Users/ross/Documents/neuroparc/atlases'\n",
    "fig_name = 'AMI_analysis'\n",
    "voxel_size = '1'\n",
    "atlases = ['AAL_space-MNI152NLin6_res-1x1x1.nii.gz','AICHAJoliot2015_space-MNI152NLin6_res-1x1x1.nii.gz']\n",
    "\n",
    "#Necessary for running this function in a jupyter notebook: sys.argv stands in for the command-line arguments\n",
    "sys.argv = ['',input_dir, '--output_dir',output_dir,'--fig_name',fig_name,'--voxel_size',voxel_size, '--atlases',atlases[0],atlases[1]]"
   ]
  },
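  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the --voxel_size filter concrete, here is a minimal sketch of how the file-name pattern selects atlases. The file names are made up for illustration, and fnmatch stands in for the glob call used by the script so that the example runs without any files on disk.\n",
    "\n",
    "```python\n",
    "from fnmatch import fnmatch\n",
    "\n",
    "# Hypothetical file names, for illustration only\n",
    "filenames = [\n",
    "    'AAL_space-MNI152NLin6_res-1x1x1.nii.gz',\n",
    "    'AAL_space-MNI152NLin6_res-2x2x2.nii.gz',\n",
    "    'notes.txt',\n",
    "]\n",
    "\n",
    "voxel_size = '1'\n",
    "dims = f'{voxel_size}x{voxel_size}x{voxel_size}'\n",
    "\n",
    "# Same pattern shape as the script's glob: anything ending in <dims>.nii.gz\n",
    "selected = [name for name in filenames if fnmatch(name, f'*{dims}.nii.gz')]\n",
    "print(selected)\n",
    "```\n",
    "\n",
    "Only the first name ends in 1x1x1.nii.gz, so it is the only one selected."
   ]
  },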
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have to define the adjusted_mutual_info function, which is where the actual AMI calculation occurs. This uses the sklearn.metrics.adjusted_mutual_info_score function to calculate the value and returns it."
   ]
  },
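  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the general definition of adjusted mutual information, as given in the scikit-learn documentation, corrects the mutual information $MI(U, V)$ between two labelings $U$ and $V$ for the agreement expected by chance:\n",
    "\n",
    "$$\\mathrm{AMI}(U, V) = \\frac{MI(U, V) - E[MI(U, V)]}{\\mathrm{avg}(H(U), H(V)) - E[MI(U, V)]}$$\n",
    "\n",
    "Here $H$ is the entropy of a labeling and $E[MI(U, V)]$ is the expected mutual information between random labelings with the same region sizes. In recent scikit-learn releases the default average is the arithmetic mean; other choices can be selected with the average_method argument. An AMI of 1 means the two atlases define the same parcellation, and values near 0 mean they agree no more than expected by chance."
   ]
  },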
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adjusted_mutual_info(atlas1, atlas2):\n",
    "    \"\"\"Calculate adjusted mutual information between two atlases\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    atlas1 : str\n",
    "        path to the first atlas being analyzed\n",
    "    atlas2 : str\n",
    "        path to the second atlas being analyzed\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        adjusted mutual information between the two atlases\n",
    "    \"\"\"\n",
    "    \n",
    "    #Load in the atlas raw data\n",
    "    at1 = nb.load(atlas1)\n",
    "    at2 = nb.load(atlas2)\n",
    "\n",
    "    #get_data() is deprecated in newer nibabel releases, so read the stored label values via dataobj\n",
    "    atlas1 = np.asanyarray(at1.dataobj)\n",
    "    atlas2 = np.asanyarray(at2.dataobj)\n",
    "    \n",
    "    #Flatten both matrices into vectors and feed them into the function\n",
    "    AMI = skm.adjusted_mutual_info_score(atlas1.flatten(),atlas2.flatten())\n",
    "\n",
    "    #Return resulting value\n",
    "    return AMI\n"
   ]
  },
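  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what the returned value means, here is a small sketch on made-up label vectors rather than real atlases. It assumes only that numpy and scikit-learn are installed; the arrays are toy data standing in for flattened atlases.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn import metrics as skm\n",
    "\n",
    "# Toy stand-ins for two flattened atlases (not real data)\n",
    "atlas_a = np.array([0, 0, 1, 1, 2, 2])\n",
    "# Same grouping of voxels as atlas_a, but with different region numbers\n",
    "atlas_b = np.array([5, 5, 3, 3, 9, 9])\n",
    "\n",
    "# AMI depends only on how voxels are grouped, not on the label values themselves\n",
    "ami_same = skm.adjusted_mutual_info_score(atlas_a, atlas_b)\n",
    "print(ami_same)\n",
    "```\n",
    "\n",
    "Because both vectors group the voxels identically, the score is 1 up to floating-point rounding, even though the region numbers differ. This is why atlases with different numbering schemes can be compared directly."
   ]
  },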
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the main function needs to be defined and called. This is where the png and csv files are created from the results of adjusted_mutual_info. The inputs specified by the user are parsed, and the atlases are iterated through to generate an AMI value for every pair of atlases.\n",
    "\n",
    "If you wish to change the format of the png or csv file, you can change any of the formatting code after the comment \"Save AMI matrix to csv file, comma delimited\" without affecting the calculations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    parser = ArgumentParser(\n",
    "        description=\"Script to calculate the adjusted mutual information of atlases\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"input_dir\",\n",
    "        help=\"\"\"The directory where the atlases you wish to analyze are saved.\"\"\",\n",
    "        action=\"store\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--output_dir\",\n",
    "        help=\"\"\"Directory to save the output AMI matrix and heatmap. If not specified, then\n",
    "        the input directory will be used. Default is None.\"\"\",\n",
    "        action=\"store\",\n",
    "        default=None,\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--fig_name\",\n",
    "        help=\"\"\"Name to use for the output png and csv files. If not specified, then\n",
    "        the name 'AMI_Matrix' will be used. Default is 'AMI_Matrix'\"\"\",\n",
    "        action=\"store\",\n",
    "        default='AMI_Matrix',\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--voxel_size\",\n",
    "        help=\"\"\"Voxel size for atlases to be analyzed. This will filter\n",
    "        the files located in input_dir for anything with\n",
    "        <atlas_name>_res-<VOX>x<VOX>x<VOX>.nii.gz. Default is 1.\"\"\",\n",
    "        action=\"store\",\n",
    "        default=\"1\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--atlases\",\n",
    "        help=\"\"\"List of atlas names to analyze, located in the directory\n",
    "        specified by input_dir. If not specified, then all atlases in the input\n",
    "        directory will be analyzed. These override the value specified in\n",
    "        '--voxel_size'. Default is None.\"\"\",\n",
    "        nargs=\"+\",\n",
    "    )\n",
    "\n",
    "    # ------- Parse CLI arguments ------- #\n",
    "    result = parser.parse_args()\n",
    "    input_dir = result.input_dir\n",
    "    output_dir = result.output_dir\n",
    "    fig_name = result.fig_name\n",
    "    voxel_size = result.voxel_size\n",
    "    atlases = result.atlases\n",
    "    \n",
    "\n",
    "    # Save outputs to input directory if output directory not specified\n",
    "    if not output_dir:\n",
    "        output_dir = input_dir\n",
    "\n",
    "    #Search for all applicable files\n",
    "    if atlases:\n",
    "        #Prepend input directory to atlas names\n",
    "        atlas_paths = [f\"{input_dir}/{i}\" for i in atlases]\n",
    "    else:\n",
    "        dims = f\"{voxel_size}x{voxel_size}x{voxel_size}\"\n",
    "        \n",
    "        #Sort so the row and column order is the same on every run\n",
    "        atlas_paths = sorted(glob.glob(f\"{input_dir}/*{dims}.nii.gz\"))\n",
    "\n",
    "\n",
    "    #Create an ndarray of zeros to be filled in, one row and column per atlas\n",
    "    n_atlases = len(atlas_paths)\n",
    "    AMI_array = np.zeros((n_atlases, n_atlases))\n",
    "\n",
    "    #Loop through all pairs of atlases and calculate AMI; the matrix is symmetric,\n",
    "    #so each pair is computed once and mirrored\n",
    "    for i in range(n_atlases):\n",
    "        for j in range(i + 1):\n",
    "            AMI = adjusted_mutual_info(atlas_paths[i],atlas_paths[j])\n",
    "            AMI_array[i, j] = AMI\n",
    "            AMI_array[j, i] = AMI\n",
    "\n",
    "    #Save AMI matrix to csv file, comma delimited\n",
    "    np.savetxt(f'{output_dir}/{fig_name}.csv', AMI_array, delimiter=\",\")\n",
    "\n",
    "    #Generate labels for figure from the file names, so this also works when --atlases is not given\n",
    "    labels = [path.split('/')[-1].split('_space-')[0] for path in atlas_paths]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    im = ax.imshow(AMI_array, cmap=\"gist_heat_r\") #Can specify the colormap you wish to use\n",
    "    ax.set_xticks(np.arange(len(labels)))\n",
    "    ax.set_yticks(np.arange(len(labels)))\n",
    "\n",
    "    ax.set_xticklabels(labels)\n",
    "    ax.set_yticklabels(labels)\n",
    "\n",
    "    #Label x and y-axis, adjust fontsize as necessary\n",
    "    plt.setp(ax.get_xticklabels(), fontsize=6, rotation=90, ha=\"right\", va=\"center\", rotation_mode=\"anchor\")\n",
    "    plt.setp(ax.get_yticklabels(), fontsize=6)\n",
    "\n",
    "    plt.colorbar(im, aspect=30)\n",
    "    ax.set_title(\"Adjusted Mutual Information Between Atlases\")\n",
    "    \n",
    "    fig.tight_layout()\n",
    "    \n",
    "    #Save the figure before showing it; saving after plt.show() can write an empty image\n",
    "    fig.savefig(f'{output_dir}/{fig_name}.png', dpi=1000)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
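  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pairwise loop in main can be tried in isolation. The sketch below is only an illustration: pairwise_matrix and toy_score are hypothetical names that are not part of the script, and the toy similarity stands in for adjusted_mutual_info so that the example runs without atlas files.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pairwise_matrix(items, score):\n",
    "    # Hypothetical helper: fill a symmetric matrix with score(items[i], items[j])\n",
    "    n = len(items)\n",
    "    matrix = np.zeros((n, n))\n",
    "    for i in range(n):\n",
    "        for j in range(i + 1):\n",
    "            # Compute each pair once and mirror it across the diagonal\n",
    "            value = score(items[i], items[j])\n",
    "            matrix[i, j] = value\n",
    "            matrix[j, i] = value\n",
    "    return matrix\n",
    "\n",
    "def toy_score(a, b):\n",
    "    # Toy similarity (fraction of matching positions), standing in for the AMI call\n",
    "    return float(np.mean(np.array(a) == np.array(b)))\n",
    "\n",
    "# Made-up label vectors standing in for flattened atlases\n",
    "toy_atlases = [[0, 0, 1, 1], [0, 1, 1, 1], [2, 2, 3, 3]]\n",
    "result = pairwise_matrix(toy_atlases, toy_score)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "Computing only the entries with j <= i and mirroring them roughly halves the number of comparisons, which matters because each AMI call on full-resolution atlases is comparatively slow."
   ]
  },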
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outputs:\n",
    "After running this script to completion, you should have the following in your output directory:\n",
    "1. A file named {fig_name}.png (AMI_Matrix.png by default), containing a heatmap of the AMI values for each pair of atlases.\n",
    "2. A file named {fig_name}.csv (AMI_Matrix.csv by default), containing the AMI matrix shown in the accompanying png file.\n",
    "\n",
    "## Common Errors:\n",
    "- Issues may arise if the atlases being compared have different voxel sizes: the AMI is computed voxel by voxel on the flattened label arrays, so atlases on different grids cannot be compared directly.\n",
    "- If the atlases you are using do not end with '{vox}x{vox}x{vox}.nii.gz', issues will arise if you are using the --voxel_size approach to selecting atlases to analyze. Either specify the exact atlases you wish to analyze, rename the atlases to follow the intended structure, or edit the first if/else statement in the main function."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.3 64-bit",
   "language": "python",
   "name": "python37364bit7aebd514ba8143f191d4fea51945eb11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
